// Copyright (C) 2019. Huawei Technologies Co., Ltd. All rights reserved.

// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), 
// to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, 
// and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE 
// WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


#ifndef _RESHAPE_H
#define _RESHAPE_H

#include "operator.hpp"
#include "tensor_computing.h"
#include "tensor_desc.h"
#include "model_tools.h"

template <Arch A>
class Reshape: public Operator<A> {
public:
    /**
     * @brief Construct a Reshape operator.
     *
     * @param dt           Data type stored on the operator (assigned to this->dt).
     * @param shapeDimsPtr Pointer to shapeSize I32 target-shape entries; the values are
     *                     copied, so the caller keeps ownership of the buffer.
     *                     NOTE(review): special entries (e.g. 0 / -1) are interpreted by
     *                     reshape_infer_output_size, not here -- confirm their meaning there.
     * @param shapeSize    Number of entries in shapeDimsPtr. A non-positive value (or a
     *                     null shapeDimsPtr) yields an empty target shape.
     * @param axis         Stored on the operator; not read by run() or
     *                     infer_output_tensors_size() in this class.
     * @param numAxes      Stored on the operator; not read by run() or
     *                     infer_output_tensors_size() in this class.
     */
    Reshape(DataType dt, I32* shapeDimsPtr, I32 shapeSize, I32 axis, I32 numAxes)
        : axis(axis), numAxes(numAxes)
    {
        this->dt = dt;
        // Only allocate/copy when there is something to copy: memcpy with a null
        // pointer is undefined even for 0 bytes, and a negative I32 size would be
        // converted to an enormous element count by the Vec constructor.
        if (shapeSize > 0 && shapeDimsPtr != nullptr) {
            this->shapeDims = Vec<I32>(shapeSize);
            memcpy(this->shapeDims.data(), shapeDimsPtr, shapeSize * sizeof(I32));
        }
        this->set_op_type(OT_Reshape);
    }

    /**
     * @brief Execute the reshape on the first input tensor, writing to the first output tensor.
     *
     * Passes both tensor descriptors and raw data pointers to the reshape() kernel for
     * architecture A, and checks the returned status with CHECK_STATUS.
     * NOTE(review): assumes inputTensors and outputTensors each hold at least one tensor.
     */
    void run() override
    {
        UTIL_TIME_TIC(__CLASS_FUNCTION__)

        Tensor inputTensor = this->inputTensors[0];
        TensorDesc inputDesc = inputTensor.get_desc();

        Tensor outputTensor = this->outputTensors[0];
        TensorDesc outputDesc = outputTensor.get_desc();

        CHECK_STATUS(reshape(inputDesc, inputTensor.get_val().get(), outputDesc, outputTensor.get_val().get(), A));

        UTIL_TIME_TOC(__CLASS_FUNCTION__)
    }

    /**
     * @brief Infer the output descriptor from the first input descriptor and the stored target shape.
     *
     * @param inDims  Input descriptors; only inDims[0] is read.
     * @param outDims Output descriptors; (*outDims)[0] is overwritten with the inferred shape.
     * @return SUCCESS when reshape_infer_output_size succeeds; otherwise presumably the error
     *         code propagated by CHECK_STATUS_WITH_RETURN (confirm against its definition).
     * NOTE(review): assumes inDims is non-empty and outDims points to a vector with at least
     * one element -- neither is checked here.
     */
    EE infer_output_tensors_size(Vec<TensorDesc> inDims, Vec<TensorDesc>* outDims) override
    {
        TensorDesc inputDesc = inDims[0];
        CHECK_STATUS_WITH_RETURN(reshape_infer_output_size(inputDesc, &((*outDims)[0]), this->shapeDims.data(), this->shapeDims.size()));
        return SUCCESS;
    }


private:
    // Copy of the target shape supplied at construction (empty if none was given).
    Vec<I32> shapeDims;
    // Stored from the constructor; unused by the methods visible in this class.
    I32 axis;
    // Stored from the constructor; unused by the methods visible in this class.
    I32 numAxes;
};

#endif //_RESHAPE_H
